{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch 翻译为 ONNX(temp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm-book/doc/tutorials/frontend\n"
     ]
    }
   ],
   "source": [
    "%cd ..\n",
    "import set_env\n",
    "from utils.onnx_utils import (\n",
    "    get_input_data_shape_dict,\n",
    "    make_constant_node, get_onnxruntime_output,\n",
    "    get_tvm_output, get_tvm_output_with_vm,\n",
    "    verify_with_ort, verify_with_ort_with_inputs,\n",
    "    quantize_and_verify_with_ort\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm import relay"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch 算子测试"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `unsqueeze_constant`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "import onnx\n",
    "from tvm import relay\n",
    "import tempfile\n",
    "\n",
    "class Flatten(nn.Module):\n",
    "    def forward(self, input_):\n",
    "        return input_.view(input_.size(0), -1)\n",
    "\n",
    "# Export into a private temporary directory: re-opening a still-open\n",
    "# NamedTemporaryFile by name is not portable (it fails on Windows).\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    file_name = os.path.join(tmp_dir, \"unsqueeze_constant.onnx\")\n",
    "    input_size = (1, 16, 32, 32)\n",
    "    dummy_input = torch.randn(*input_size)\n",
    "    layer = nn.Sequential(Flatten(), nn.Linear(16 * 32 * 32, 64))\n",
    "    torch.onnx.export(layer, dummy_input, file_name,\n",
    "                      input_names=[\"data\"],\n",
    "                      export_params=True)\n",
    "\n",
    "    onnx_model = onnx.load(file_name)\n",
    "    # Importing into Relay is the check itself; the result is not used.\n",
    "    relay.frontend.from_onnx(onnx_model, {\"data\": input_size})"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `embedding_bag`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_aten(target, dev):\n",
    "    \"\"\"Check ATen-exported ``EmbeddingBag`` models against PyTorch.\n",
    "\n",
    "    Each model is exported with ``OperatorExportTypes.ONNX_ATEN``, run through\n",
    "    ``get_tvm_output_with_vm`` and compared with the PyTorch output.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    target, dev\n",
    "        Forwarded unchanged to ``get_tvm_output_with_vm``.\n",
    "    \"\"\"\n",
    "    # NOTE: this disables autograd for the rest of the kernel session, not\n",
    "    # just for this function.\n",
    "    torch.set_grad_enabled(False)\n",
    "\n",
    "    def _convert_to_onnx(model, inputs):\n",
    "        \"\"\"Export ``model`` to ONNX (ATen ops, opset 10) and load it back.\"\"\"\n",
    "        # Export into a throwaway directory so no ``aten_model.onnx`` is left\n",
    "        # behind in the working directory.\n",
    "        with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "            file_name = os.path.join(tmp_dir, \"aten_model.onnx\")\n",
    "            torch.onnx.export(\n",
    "                model,\n",
    "                inputs,\n",
    "                file_name,\n",
    "                export_params=True,\n",
    "                verbose=False,\n",
    "                opset_version=10,\n",
    "                operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,\n",
    "            )\n",
    "            return onnx.load(file_name)\n",
    "\n",
    "    def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None):\n",
    "        \"\"\"Compare one ``EmbeddingBag`` configuration (``num_bags`` is unused).\"\"\"\n",
    "        # ``high`` is exclusive in ``torch.randint``, so pass ``num_embedding``\n",
    "        # itself to let the last embedding row be indexed as well.\n",
    "        dummy_data = torch.randint(0, num_embedding, data_shape)\n",
    "        tvm_inputs = [dummy_data.numpy()]\n",
    "        model = torch.nn.EmbeddingBag(num_embedding, embedding_dim)\n",
    "        onnx_model = _convert_to_onnx(model, dummy_data)\n",
    "        torch_out = model(dummy_data)\n",
    "        tvm_out = get_tvm_output_with_vm(\n",
    "            onnx_model,\n",
    "            tvm_inputs,\n",
    "            freeze_params=True,\n",
    "            target=target,\n",
    "            dev=dev,\n",
    "        )\n",
    "        np.testing.assert_allclose(torch_out.numpy(), tvm_out, atol=5e-7)\n",
    "\n",
    "    verify_embedding_bag(10, 3, [2, 10])\n",
    "    verify_embedding_bag(32, 2, [3, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the EmbeddingBag export check with the \"llvm\" target on tvm.cpu().\n",
    "test_aten(\"llvm\", tvm.cpu())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `index_put` 与 `index_put_slice`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class IndexPutModel(torch.nn.Module):\n",
    "    def __init__(self, indices, values, accumulate):\n",
    "        super().__init__()\n",
    "        self.indices = indices\n",
    "        self.values = values\n",
    "        self.accumulate = accumulate\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x.index_put(self.indices, self.values, self.accumulate)\n",
    "\n",
    "def _convert_to_onnx(model, dummy_data):\n",
    "    \"\"\"Export ``model`` to ONNX and return the loaded ONNX model.\n",
    "\n",
    "    The export uses opset 11 with ``OperatorExportTypes.ONNX_ATEN_FALLBACK``\n",
    "    and ``dummy_data`` as the example input.\n",
    "    \"\"\"\n",
    "    # Export into a throwaway directory so no ``aten_model.onnx`` is left\n",
    "    # behind in the working directory.\n",
    "    with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "        file_name = os.path.join(tmp_dir, \"aten_model.onnx\")\n",
    "        torch.onnx.export(\n",
    "            model,\n",
    "            dummy_data,\n",
    "            file_name,\n",
    "            export_params=True,\n",
    "            verbose=False,\n",
    "            opset_version=11,\n",
    "            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,\n",
    "        )\n",
    "        onnx_model = onnx.load(file_name)\n",
    "    return onnx_model\n",
    "\n",
    "def verify_index_put(data_shape, indices, accumulate):\n",
    "    \"\"\"Compare PyTorch and TVM results of ``index_put`` on an all-ones tensor.\n",
    "\n",
    "    ``data_shape`` is the shape of the all-ones input, ``indices`` is a\n",
    "    sequence of index tensors (one random value is drawn per element of\n",
    "    ``indices[0]``) and ``accumulate`` is forwarded to ``index_put``.\n",
    "    \"\"\"\n",
    "    ones = torch.ones(data_shape)\n",
    "    updates = torch.rand(indices[0].size())\n",
    "    module = IndexPutModel(indices, updates, accumulate)\n",
    "\n",
    "    graph = _convert_to_onnx(module, ones)\n",
    "    expected = module(ones).numpy()\n",
    "\n",
    "    # The exported graph is run with the \"llvm\" target on tvm.cpu().\n",
    "    actual = get_tvm_output_with_vm(\n",
    "        graph, [ones.numpy()], \"llvm\", tvm.cpu(), freeze_params=True\n",
    "    )\n",
    "    tvm.testing.assert_allclose(expected, actual)\n",
    "\n",
    "# Case 1: 2-D data, two index tensors, accumulate=True.\n",
    "shape = (3, 5)\n",
    "xidx = torch.tensor([0, 1, 2, 2])\n",
    "yidx = torch.tensor([0, 1, 3, 4])\n",
    "verify_index_put(shape, [xidx, yidx], True)\n",
    "\n",
    "# Case 2: 3-D data, three index tensors, accumulate=False.\n",
    "# The coordinate (0, 0, 0) appears twice (first and last entries).\n",
    "shape = (3, 5, 3)\n",
    "xidx = torch.tensor([0, 1, 2, 2, 0])\n",
    "yidx = torch.tensor([0, 1, 3, 4, 0])\n",
    "zidx = torch.tensor([0, 1, 1, 2, 0])\n",
    "verify_index_put(shape, [xidx, yidx, zidx], False)\n",
    "\n",
    "def verify_index_put_slice(data_shape, value_shape, accumulate):\n",
    "    \"\"\"Compare PyTorch and TVM results of a slice-style ``index_put``.\n",
    "\n",
    "    One ``arange`` index tensor is built per entry of ``value_shape`` and\n",
    "    random values of shape ``value_shape`` are written into an all-ones\n",
    "    tensor of shape ``data_shape``; ``accumulate`` is forwarded to\n",
    "    ``index_put``.\n",
    "    \"\"\"\n",
    "    ones = torch.ones(data_shape)\n",
    "\n",
    "    # Starts as [-1, 1, ..., 1] and loses one trailing entry per index tensor.\n",
    "    view_shape = [1] * len(value_shape)\n",
    "    view_shape[0] = -1\n",
    "    indices = []\n",
    "    for extent in value_shape:\n",
    "        indices.append(torch.arange(0, extent).reshape(tuple(view_shape)))\n",
    "        view_shape.pop()\n",
    "    updates = torch.rand(value_shape)\n",
    "\n",
    "    module = IndexPutModel(indices, updates, accumulate)\n",
    "    graph = _convert_to_onnx(module, ones)\n",
    "    expected = module(ones).numpy()\n",
    "\n",
    "    # The exported graph is run with the \"llvm\" target on tvm.cpu().\n",
    "    actual = get_tvm_output_with_vm(\n",
    "        graph, [ones.numpy()], \"llvm\", tvm.cpu(), freeze_params=True\n",
    "    )\n",
    "    np.testing.assert_allclose(expected, actual)\n",
    "\n",
    "# (data_shape, value_shape, accumulate) combinations to check, in order.\n",
    "for slice_case in (\n",
    "    ((3, 3), (2, 2), False),\n",
    "    ((2, 3, 4), (1, 2, 3), True),\n",
    "    ((2, 3, 4, 5), (1, 2, 3, 1), False),\n",
    "):\n",
    "    verify_index_put_slice(*slice_case)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `torchvision` 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "import onnx\n",
    "\n",
    "def check_torch_conversion(model, input_size, target, dev):\n",
    "    \"\"\"Export a model to ONNX and check it with ``verify_with_ort_with_inputs``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : callable\n",
    "        Zero-argument model factory such as ``torchvision.models.resnet18``;\n",
    "        its ``__name__`` is used for the exported file name.\n",
    "    input_size : tuple of int\n",
    "        Shape of the random tensor used for export and of the random\n",
    "        float32 array used for verification.\n",
    "    target, dev\n",
    "        Forwarded unchanged to ``verify_with_ort_with_inputs``.\n",
    "    \"\"\"\n",
    "    dummy_input = torch.randn(*input_size)\n",
    "    # Export into a throwaway directory so the exported ``.onnx`` file is not\n",
    "    # left behind in the working directory.\n",
    "    with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "        file_name = os.path.join(tmp_dir, f\"{model.__name__}.onnx\")\n",
    "        # Set `verbose=True` to show more information.\n",
    "        torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)\n",
    "        onnx_model = onnx.load(file_name)\n",
    "    input_data = np.random.uniform(size=input_size).astype(\"float32\")\n",
    "    verify_with_ort_with_inputs(\n",
    "        onnx_model, [input_data], apply_softmax=True, target=target, dev=dev\n",
    "    )\n",
    "\n",
    "# def test_alexnet():\n",
    "# Torch's ONNX export does not support the adaptive pooling used by AlexNet?\n",
    "# check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))\n",
    "\n",
    "# Torch's ONNX export does not support the adaptive pooling used by vgg16?\n",
    "# def test_vgg16():\n",
    "#     check_torch_conversion(torchvision.models.vgg16, (1,3,224,224))\n",
    "\n",
    "# TODO(@jroesch): Update Torch + ONNX to support this import.\n",
    "# def test_squeezenet():\n",
    "#     # Torch's ONNX export does not support the max pooling used by SqueezeNet\n",
    "#     check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))\n",
    "\n",
    "# TODO(@jroesch): Update Torch + ONNX to support this import.\n",
    "# def test_googlenet():\n",
    "#     check_torch_conversion(torchvision.models.googlenet, (1,3,224,224))\n",
    "\n",
    "# TODO(@jroesch): Update Torch + ONNX to support this import.\n",
    "# def test_shufflenetv2():\n",
    "#     check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))\n",
    "\n",
    "@tvm.testing.parametrize_targets\n",
    "def test_densenet(target, dev):\n",
    "    \"\"\"Run ``check_torch_conversion`` on ``densenet161`` with a (1, 3, 224, 224) input.\"\"\"\n",
    "    input_size = (1, 3, 224, 224)\n",
    "    check_torch_conversion(torchvision.models.densenet161, input_size, target, dev)\n",
    "\n",
    "\n",
    "@tvm.testing.parametrize_targets\n",
    "def test_inception(target, dev):\n",
    "    \"\"\"Run ``check_torch_conversion`` on ``inception_v3`` with a (1, 3, 224, 224) input.\"\"\"\n",
    "    input_size = (1, 3, 224, 224)\n",
    "    check_torch_conversion(torchvision.models.inception_v3, input_size, target, dev)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `resnet18`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the resnet18 export with the \"llvm\" target on tvm.cpu().\n",
    "target = \"llvm\"\n",
    "dev = tvm.cpu()\n",
    "check_torch_conversion(torchvision.models.resnet18, (1, 3, 224, 224), target, dev)\n",
    "# Disabled: the call below omits the target/dev arguments that\n",
    "# check_torch_conversion requires, so it cannot be re-enabled as written.\n",
    "# check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
